#!/usr/bin/env python
# -*- coding: utf-8 -*-

""" By Martin Senande-Rivera
	For Spatial and temporal expansion of global wildland fire activity in response to climate change """

import numpy as np
import pandas as pd
import xarray as xr

### File reading ###
####################

# Directory holding the WFDE5 NetCDF inputs
ruta_wfde5 = '../../DATA/WFDE5/'

# Climatologies: annual means (MAP, MAT) and per-calendar-month means (Pm, Tm)
ds_tp_a, ds_t2m_a, ds_tp_m, ds_t2m_m = (
    xr.open_dataset(ruta_wfde5 + fname)
    for fname in (
        'wfde5_tp_a_mean_interp025.nc',
        'wfde5_t2m_a_mean_interp025.nc',
        'wfde5_tp_m_mean_interp025.nc',
        'wfde5_t2m_m_mean_interp025.nc',
    )
)
MAP = ds_tp_a.tp      # total precipitation, annual mean
MAT = ds_t2m_a.t2m    # 2 m temperature, annual mean
Pm = ds_tp_m.tp       # total precipitation, mean of each calendar month
Tm = ds_t2m_m.t2m     # 2 m temperature, mean of each calendar month

# Monthly time series
ds_tp_months = xr.open_dataset(ruta_wfde5 + 'wfde5_tp_months_interp025.nc')
ds_t2m_months = xr.open_dataset(ruta_wfde5 + 'wfde5_t2m_months_interp025.nc')
P_months, T_months = ds_tp_months.tp, ds_t2m_months.t2m

# Per-year aggregates of the monthly series (one groupby object per variable,
# reused for both reductions)
P_by_year = P_months.groupby('time.year')
T_by_year = T_months.groupby('time.year')
P_years = P_by_year.sum(dim='time')      # yearly precipitation total
T_years = T_by_year.mean(dim='time')     # yearly mean temperature
Pmin_years = P_by_year.min(dim='time')   # precipitation of each year's driest month
Tmax_years = T_by_year.max(dim='time')   # temperature of each year's hottest month

# KG (presumably Koppen-Geiger) classification, read from the
# threshold-selection directory
ruta_thres = '../1-Threshold_selection/'
ds_kg = xr.open_dataset(ruta_thres + 'KG_class.nc')
KG = ds_kg.KG

### Variable calculation ###
############################

# Half-year seasons: October-March (ONDJFM) and April-September (AMJJAS)
ONDJFM = [10, 11, 12, 1, 2, 3]
AMJJAS = [4, 5, 6, 7, 8, 9]

# Seasonal precipitation totals and seasonal mean temperatures
P_ONDJFM, P_AMJJAS = (Pm.sel(month=season).sum(dim='month') for season in (ONDJFM, AMJJAS))
T_ONDJFM, T_AMJJAS = (Tm.sel(month=season).mean(dim='month') for season in (ONDJFM, AMJJAS))

# Winter precipitation (P_winter) is that of the colder half-year and summer
# precipitation (P_summer) that of the warmer one. Where neither half-year is
# strictly colder (equal or NaN temperatures) both stay 0.
no_season = xr.full_like(MAP, fill_value=0)
ondjfm_colder = T_ONDJFM < T_AMJJAS
amjjas_colder = T_AMJJAS < T_ONDJFM
P_winter = xr.where(ondjfm_colder, P_ONDJFM, xr.where(amjjas_colder, P_AMJJAS, no_season))
P_summer = xr.where(amjjas_colder, P_ONDJFM, xr.where(ondjfm_colder, P_AMJJAS, no_season))

# Precipitation threshold (P_treshold) used by the arid test of the
# classification: 2*MAT + 14 by default, 2*MAT when more than 70 % of MAP falls
# in winter, 2*MAT + 28 when more than 70 % falls in summer (the summer rule
# takes precedence).
P_treshold = xr.where(P_summer > MAP*0.7, 2*MAT+28.,
                      xr.where(P_winter > MAP*0.7, 2*MAT, 2*MAT+14.))

# Temperature of the coldest (Tmin) and hottest (Tmax) month
Tmin, Tmax = Tm.min(dim='month'), Tm.max(dim='month')

# Precipitation of the driest (Pmin) and wettest (Pmax) month
Pmin, Pmax = Pm.min(dim='month'), Pm.max(dim='month')

# Number (1-12) of the wettest month: argmax position along 'month' + 1.
# NOTE(review): assumes Pm's month axis is ordered 1..12 - confirm.
Pmax_month = Pm.argmax(dim='month', skipna=True) + 1

### Thresholds ###
##################
# Working values used by the classification. The trailing comments are kept
# from the original; they read as "estimate ± uncertainty" for each threshold.
#   pa_*   : minimum yearly precipitation total (compared with P_years)
#   pmin_* : upper limit for the driest qualifying month of the year
#   pm_*   : maximum precipitation of a fire-prone month
#   tm_*   : temperature a fire-prone month must exceed
#   tmax_* : temperature the hottest qualifying month must reach (Bo-hs)

# Tr-ds
pa_fc1, pmin_fc1, pm_fc1 = 220., 6., 90.                # 227.1 ± 7.6 | 6.0 ± 0.5 | 90.0 ± 1.5

# Ar-dhs
pa_fc2, pm_fc2, tm_fc2 = 220., 60., 19.5                # 227.1 ± 7.6 | 60.1 ± 0.0 | 19.67 ± 0.25

# Te-dhs
pa_fc3, pmin_fc3, pm_fc3, tm_fc3 = 220., 13., 42., 12.  # 227.1 ± 7.6 | 13.0 ± 0.5 | 42.0 ± 1.5 | 12.18 ± 0.25

# Bo-hs
pa_fc4, pm_fc4, tm_fc4, tmax_fc4 = 220., 67., 7., 15.   # 227.1 ± 7.6 | 66.7 ± 0.0 | 6.79 ± 0.25 | 14.97 ± 0.15

# Auxiliary thresholds for the arid (Ar-dhs) sub-classes: MAT limits that split
# arid points into cold / warm / hot, and windows of months elapsed since the
# wettest month (dmP) in which warm / hot arid points can be fire-prone.
mat_fc2_warm, mat_fc2_hot = 18.5, 27.5
dmp_fc2_warm1, dmp_fc2_warm2 = 6., 10.
dmp_fc2_hot1, dmp_fc2_hot2 = 1., 4.


### Classification ###
######################

# FC_array holds one fire-climate code per grid cell and time step, with the
# same dims/coords as P_months (documented shape: time-252, lat-720, lon-1440):
# 0 - No fires
# 1 - Tr-ds fires
# 2 - Ar-dhs fires
# 3 - Te-dhs fires
# 4 - Bo-hs fires
# It starts at 0 everywhere; fire-prone months are filled in by the loop below.
FC_array = xr.full_like(P_months, fill_value=0.)

# General climates as boolean numpy masks on MAP's grid. The humid and arid
# tests are evaluated separately (not as complements) so NaN cells fail both.
aridity_limit = P_treshold*10.
humid = MAP >= aridity_limit
arid = MAP < aridity_limit
A = (humid & (Tmin >= 18.)).values                                    # Tropical
B_cold = (arid & (MAT <= mat_fc2_warm)).values                        # Arid cold
B_warm = (arid & (MAT > mat_fc2_warm) & (MAT <= mat_fc2_hot)).values  # Arid warm
B_hot = (arid & (MAT > mat_fc2_hot)).values                           # Arid hot
C = (humid & (Tmin < 18.) & (MAT >= 2.)).values                       # Temperate
D = (humid & (MAT < 2.) & (Tmax > 10.)).values                        # Cold

### Fire Classification
# For every year (1996-2016) and month, mark the grid cells whose general
# climate (A, B_*, C, D) and year/month conditions meet the fire-climate
# thresholds. Within a month later assignments overwrite earlier ones.

# Months elapsed since each cell's wettest month (0-11) depend only on the
# calendar month, so the arid warm/hot timing windows are built once here
# instead of being recomputed inside the year loop.
dmP_windows = {}
for m in range(1, 13):
    dmP = m - Pmax_month
    dmP = xr.where(dmP < 0, dmP + 12, dmP)
    dmP_windows[m] = ((dmP > dmp_fc2_warm1) & (dmP <= dmp_fc2_warm2),
                      (dmP > dmp_fc2_hot1) & (dmP <= dmp_fc2_hot2))

for y in range(1996, 2017):
    # Select the year once instead of repeating the label lookup
    P_y = P_months.sel(time=str(y))
    T_y = T_months.sel(time=str(y))
    P_annual = P_years.sel(year=y)

    # Driest qualifying month (Tr-ds, Te-dhs) and hottest qualifying month (Bo-hs)
    Pmin_fc1 = P_y.where(P_y <= pm_fc1).min(dim='time')
    Pmin_fc3 = P_y.where((P_y <= pm_fc3) & (T_y >= tm_fc3)).min(dim='time')
    Tmax_fc4 = T_y.where((P_y <= pm_fc4) & (T_y >= tm_fc4)).max(dim='time')

    # Year-level conditions as boolean numpy arrays
    FC1y = ((P_annual >= pa_fc1) & (Pmin_fc1 <= pmin_fc1)).values
    FC2y = (P_annual >= pa_fc2).values
    FC3y = ((P_annual >= pa_fc3) & (Pmin_fc3 <= pmin_fc3)).values
    FC4y = ((P_annual >= pa_fc4) & (Tmax_fc4 >= tmax_fc4)).values

    for m in range(1, 13):
        t_label = f'{y}-{m:02d}'
        P_m = P_months.sel(time=t_label)
        T_m = T_months.sel(time=t_label)
        warm_window, hot_window = dmP_windows[m]

        # Month-level conditions as boolean numpy arrays
        FC1m = (P_m <= pm_fc1).values
        dry_hot_fc2 = (P_m <= pm_fc2) & (T_m > tm_fc2)
        FC2m_cold = dry_hot_fc2.values
        FC2m_warm = (dry_hot_fc2 & warm_window).values
        FC2m_hot = (dry_hot_fc2 & hot_window).values
        FC3m = ((P_m <= pm_fc3) & (T_m > tm_fc3)).values
        FC4m = ((P_m <= pm_fc4) & (T_m > tm_fc4)).values

        # Edit an explicit copy of this month and write it back with .loc.
        # Assigning into FC_array.sel(...).values only works when .sel happens
        # to return a view; if the label resolves to an index array it returns
        # a copy and the assignments are silently lost.
        month_fc = FC_array.sel(time=t_label).values.copy()
        month_fc[A & FC1y & FC1m] = 1
        month_fc[B_cold & FC2y & FC2m_cold] = 2
        month_fc[B_warm & FC2y & FC2m_warm] = 2
        month_fc[B_hot & FC2y & FC2m_hot] = 2
        month_fc[C & FC3y & FC3m] = 3
        month_fc[D & FC4y & FC4m] = 4
        FC_array.loc[dict(time=t_label)] = month_fc

# Persist the monthly fire-climate classification
FC_array.name = 'FC'
FC_array.to_netcdf('FC_array.nc')

# Years_FC: number of processed years (21, i.e. 1996-2016) that contain at
# least one fire-prone month, expressed on a 0-10 scale.
# Shape: lat-720, lon-1440
FC_years = FC_array.groupby('time.year').max(dim='time')   # highest code reached in each year
n_fire_years = FC_years.where(FC_years > 0).count(dim='year')
Years = xr.full_like(MAP, fill_value=0.)
Years.values = n_fire_years / 21. * 10.
Years.name = 'Years_FC'
Years.to_netcdf('Years_FC.nc')

# PFS: mean Potential Fire Season length in months, averaged only over the
# years that have at least one fire-prone month (0 where no year does).
# Shape: lat-720, lon-1440
fire_months_per_year = FC_array.where(FC_array > 0).groupby('time.year').count(dim='time')
PFS_years = xr.where(fire_months_per_year == 0, np.nan, fire_months_per_year)  # ignore fire-free years
PFS = xr.full_like(MAP, fill_value=0.)
PFS.values[:] = PFS_years.mean(dim='year', skipna=True)
PFS = xr.where(np.isnan(PFS), 0., PFS)  # never fire-prone -> 0
PFS.name = 'PFS'
PFS.to_netcdf('PFS.nc')

# FC: one code per grid point (shape lat-720, lon-1440); 0 where no year is
# fire-prone. Tens digit = class taken from KG (classes >= 4 grouped as 4);
# units digit = how often the point is fire-prone, using Years
# (fire-prone years / 21 * 10):
#   1 -> Years >= 7,   2 -> 3 < Years < 7,   3 -> 0 < Years <= 3
# NOTE(review): assumes KG codes 1-4 line up with fire climates 1-4 - confirm
# against KG_class.nc.
FC = xr.full_like(MAP, fill_value=0)
for kg_class in (1, 2, 3, 4):
    in_class = (KG >= 4) if kg_class == 4 else (KG == kg_class)
    FC = xr.where(in_class & (Years >= 7.), 10*kg_class + 1, FC)
    FC = xr.where(in_class & (Years > 3.) & (Years < 7.), 10*kg_class + 2, FC)
    FC = xr.where(in_class & (Years > 0.) & (Years <= 3.), 10*kg_class + 3, FC)
FC.name = 'FC'
FC.to_netcdf('FC.nc')

